import logging
import os
from typing import Dict, List, Optional, Union
import torchaudio
import torch
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

from wenet.transformer.encoder import TransformerEncoder, TransformerEncoder2
from wenet.llm_asr.utils4llmasr import *
from gxl_ai_utils.utils import utils_file

from wenet.llm_asr.downsampler import get_downsampler, LyzConv1dSubsampling
from wenet.transformer.swish import New_gelu4npu
from wenet.utils.mask import make_pad_mask
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import math
import gc


# import torch_npu
# from torch_npu.contrib import transfer_to_npu

# from msprobe.pytorch import seed_all,PrecisionDebugger

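# An alternative, larger speech head (referenced only by the commented-out
# block in LLMASR_Model.__init__): a small Transformer encoder followed by a
# linear projection, applied to the LLM hidden states instead of a plain
# nn.Linear.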
class SpeechBigHead(nn.Module):
    def __init__(self,
                 encoder,
                 linear_head,
                 ):
        super().__init__()
        self.encoder = encoder
        self.linear_head = linear_head
        self.dropout = nn.Dropout(0.1)
        self.gelu = New_gelu4npu()
    def forward(self, x, mask=None):
        if mask is None:
            # No lengths provided: treat every frame as valid.
            mask = ~make_pad_mask(torch.tensor([x.size(1)] * x.size(0), device=x.device, dtype=torch.int64)).to(x.device)
        x, _ = self.encoder(x, mask)
        x = self.dropout(x)
        x = self.gelu(x)
        x = self.linear_head(x)
        return x

class LLMASR_Model(nn.Module):
    def __init__(self,
                 encoder,
                 encoder_output_dim,
                 llm_path,
                 lora=True, lora_alpha=32, lora_rank=8, lora_dropout=0.1,
                 is_inference=False,
                 downsample_rate=1,
                 adapter_type='lyz',
                 speech_token_num=0,
                 train_speech_out=False):
        """"""
        super().__init__()
        utils_file.logging_limit_print(f"instruct_version: LLMASR_Model init, is_inference={is_inference}, downsample_rate={downsample_rate}, adapter_type={adapter_type}, speech_token_num={speech_token_num}, train_speech_out={train_speech_out}")
        self.downsample_rate = downsample_rate

        self.encoder = encoder
        self.ln_speech = nn.LayerNorm(encoder_output_dim)

        # adapter between encoder and LLM (~51.6M params)
        if adapter_type == 'gxl':
            self.speech_transformer = TransformerEncoder(
                input_size=encoder_output_dim,
                output_size=encoder_output_dim,
                attention_heads=4,
                linear_units=2560,
                num_blocks=4,
                dropout_rate=0.1,
                positional_dropout_rate=0.1,
                attention_dropout_rate=0.0,
                input_layer="linear",
                pos_enc_layer_type="abs_pos",
                normalize_before=True
            )
        else:
            self.speech_transformer = None

        self.llama_model = AutoModelForCausalLM.from_pretrained(
            llm_path,
            # torch_dtype=torch.float32 if is_inference else torch.float16,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            output_hidden_states=True,
        )

        self.max_length = 400
        self.min_length = 1
        self.num_beams = 4
        self.do_sample = True
        self.top_p = 0.9
        self.top_k = 5
        self.repetition_penalty = 1.05
        self.length_penalty = 1.0
        self.temperature = 1.0
        self.IGNORE_ID = -100

        # lora
        self.lora = lora
        if lora:
            utils_file.logging_limit_print("耿雪龙： 使用lora了")
            # target_modules = ['w_pack', 'o_proj', 'gate_proj', 'down_proj']
            target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj']
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=is_inference,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=target_modules,
            )
            self.llama_model = get_peft_model(self.llama_model, self.peft_config)

        # tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            llm_path, use_fast=False, trust_remote_code=True)
        """
        设置分词器的pad_token和padding的方向。
        """
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        self.tokenizer.padding_side = "right"
        self.eos_token_id = self.tokenizer.eos_token_id

        if hasattr(self.llama_model.config, 'hidden_size'):
            utils_file.logging_limit_print(
                f"self.llama_model.config.hidden_size: {self.llama_model.config.hidden_size}")
            if adapter_type == 'lyz':
                self.down_sample_2 = LyzConv1dSubsampling(encoder_output_dim, self.llama_model.config.hidden_size)
            elif adapter_type == 'gxl':
                self.down_sample_2 = get_downsampler(downsample_rate, encoder_output_dim)
                self.speech_llama_proj = nn.Linear(
                    encoder_output_dim, self.llama_model.config.hidden_size)
        else:
            raise NotImplementedError("llama_model.config has no hidden_size attribute")

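        # With LoRA, PEFT wraps the base model one level deeper (PeftModel ->
        # base CausalLM -> inner model), so the embedding table and LM head are
        # reached through an extra .model attribute.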
        self.embed_tokens = self.llama_model.model.model.embed_tokens if self.lora else self.llama_model.model.embed_tokens
        self.lm_head = self.llama_model.model.lm_head if self.lora else self.llama_model.lm_head
        self.llm_vocab_size = self.lm_head.weight.shape[0]
        self.speech_token_num = speech_token_num
        # init speech token module
        if speech_token_num > 0:
            utils_file.logging_info(f'Geng Xuelong: enabling speech-token generation, speech_token_num: {speech_token_num}')
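            # speech_token_num + 2 embedding rows: the speech tokens plus two
            # extra entries serving as speech bos/eos (ids speech_token_num and
            # speech_token_num + 1; see the _add_bos_eos calls in forward()).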
            self.speech_token_emded = torch.nn.Embedding(speech_token_num + 2, self.llama_model.config.hidden_size)
            self.speech_head = torch.nn.Linear(self.llama_model.config.hidden_size, speech_token_num)
            # self.speech_head = SpeechBigHead(
            #     TransformerEncoder2(
            #         input_size=self.llama_model.config.hidden_size,
            #         output_size=self.llama_model.config.hidden_size,
            #         attention_heads=8,
            #         linear_units=4096,
            #         num_blocks=8,
            #         dropout_rate=0.1,
            #         positional_dropout_rate=0.1,
            #         attention_dropout_rate=0.1,
            #         input_layer="linear",
            #         pos_enc_layer_type="abs_pos",
            #         normalize_before=True
            #     ),
            #     torch.nn.Linear(self.llama_model.config.hidden_size, speech_token_num, bias=False))
        else:
            # no speech-token generation: use no-op placeholders
            self.speech_head = nn.Identity()
            self.speech_token_emded = nn.Identity()
            self.speech_model = nn.Identity()
        self.train_speech_out = train_speech_out
        utils_file.logging_info(f'Geng Xuelong: train speech output: {self.train_speech_out}')
        self.loss_fct = CrossEntropyLoss(reduction='mean')
        self.unk_token_id = 7672  # token id corresponding to "&&"

    def get_label_embedding(self, labels, labels_lengths, unk_id=7672):
        """"""
        labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
        unk_mask = (labels == unk_id)  # B, L
        labels_pad_mask = labels_pad_mask | unk_mask  #
        labels = labels.masked_fill(labels_pad_mask, 0)
        labels_embeds = self.embed_tokens(labels)
        labels_target = labels.masked_fill(labels_pad_mask, self.IGNORE_ID)  # B, L
        labels_mask = ~labels_pad_mask
        return labels_embeds, labels_target, labels_mask

    def get_speech_token_label_embedding(self, speech_token_labels, speech_tokens_length):
        """"""
        speech_tokens_pad_mask = make_pad_mask(speech_tokens_length)  # B, L
        speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0)
        speech_token_labels_embeds = self.speech_token_emded(speech_token_labels)
        # Shift ids into the extended vocabulary: speech-token targets occupy
        # [llm_vocab_size, llm_vocab_size + speech_token_num) in the combined
        # text + speech logits.
        speech_token_labels = speech_token_labels + self.llm_vocab_size
        speech_token_labels_target = speech_token_labels.masked_fill(speech_tokens_pad_mask, self.IGNORE_ID)  # B, L
        speech_token_labels_mask = ~speech_tokens_pad_mask
        return speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask

    def _get_embedding_for_history(self, history_batch, device):
        """
        prompt_patern1,prompt,history, wav, prompt_patern2,txt,answer_wav,
        historcy_batch的内容是：
        [   big_embed,
            [

                {wav: feat （L，D:80)，->过encoder+link ,得到（L1, 2048)
                 txt: labels (L，), ->labels_embeds = self.embed_tokens(labels) -> (L2, 2048)， 带txt eos
                },->(L1+L2, 2048)
                {wav: feat （L，D)，
                 txt: labels (L，),
                },->(L3+L4, 2048)
            ]-> (L1+L2+L3+L4, 2048)，len:L1+L2+L3+L4
            [],->(1, 2048) ,len:0
            [
                {wav: feat （L，D)，
                 txt: labels (L，),
                },
            ],-> (L1+L2, 2048),len:L1+L2
        ]
        将每一条的历史信息的embedding拼接起来,如果有空历史信息，则用0pad, 最后得到pad后的history_embedding(B, L, D), history_lens(B)
        Args:
            history_batch:
            device:

        Returns:
            history_embedding: B, L, D
            history_lens: B

        """
        assistant_start = "<|im_end|>\n<|im_start|>assistant\n"
        assistant_start_id = self.tokenizer([assistant_start], return_tensors="pt")['input_ids'].to(device)
        assistant_start_embedding = self.embed_tokens(assistant_start_id.squeeze(0))
        assistant_end = "<|im_end|>\n"
        assistant_end_id = self.tokenizer([assistant_end], return_tensors="pt")['input_ids'].to(device)
        assistant_end_embedding = self.embed_tokens(assistant_end_id.squeeze(0))
        user_start = "<|im_start|>user\n"
        user_start_id = self.tokenizer([user_start], return_tensors="pt")['input_ids'].to(device)
        user_start_embedding = self.embed_tokens(user_start_id.squeeze(0))
        user_end = "<|im_end|>\n"
        user_end_id = self.tokenizer([user_end], return_tensors="pt")['input_ids'].to(device)
        user_end_embedding = self.embed_tokens(user_end_id.squeeze(0))
        batch_embeddings = []
        history_lens = []
        # Skip entirely if no sample carries any history.
        if all(len(history) == 0 for history in history_batch):
            return None, None

        for history in history_batch:
            history_embeds = []

            for item in history:
                wav_feat = item['wav'].to(device)  # (L, D)
                wav_feat = wav_feat.unsqueeze(0)  # (1, L, D)
                wav_embed, wav_mask = self._get_embedding_from_wav(wav_feat, torch.tensor([wav_feat.size(1)], device=device, dtype=torch.long))
                wav_embed = wav_embed.squeeze(0)  # shape: (L, D)
                if len(history_embeds) != 0:
                    history_embeds.append(user_start_embedding)  # skip user-start for the first turn: the system prompt already opens it
                history_embeds.append(wav_embed)
                history_embeds.append(user_end_embedding)
                history_embeds.append(assistant_start_embedding)
                labels = item['txt']  # shape: (L,)
                labels = torch.tensor(labels, device=device, dtype=torch.long)
                embed = self.embed_tokens(labels)  # (L2, D); usually L2 == L
                history_embeds.append(embed)
                history_embeds.append(assistant_end_embedding)
            history_embeds.append(user_start_embedding)  # open the upcoming user turn

            if history_embeds:
                # Concatenate this sample's turn embeddings: (sum(Li), D)
                full_embed = torch.cat(history_embeds, dim=0)
                history_lens.append(full_embed.size(0))
            else:
                # Empty history: a single zero vector with length 0.
                full_embed = torch.zeros((1, self.embed_tokens.embedding_dim), device=device)
                history_lens.append(0)

            batch_embeddings.append(full_embed)

        # Pad to the longest history in the batch.
        padded_embeddings = pad_sequence(batch_embeddings, batch_first=True, padding_value=0.0)  # (B, L, D)
        history_lens = torch.tensor(history_lens, device=device, dtype=torch.long)
        padded_embeddings = padded_embeddings.to(device)

        return padded_embeddings, history_lens


    def forward(self,
                batch,
                device,
                ):
        """"""
        output_type = batch['output_type']
        # qwen_instruct_prompt_pattern_chat = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
        qwen_instruct_prompt_pattern_chat_s2s = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_s2s_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech. Before responding, first output your reasoning inside <think>...</think end>, analyzing the user’s words and vocal cues. Then generate a reply with appropriate text and emotionally matched synthetic speech.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_s2s_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (content + paralinguistic cues) and respond with interleaved text and emotionally-matched synthetic speech.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_s2s_streaming_think = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You analyze speech (both content and paralinguistic cues). Before responding, output your reasoning in <think>...</think end>. Then reply with interleaved text and emotionally matched synthetic speech.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_s2t = "<|im_start|>system\nYou are OSUM-chat, a speech-to-text dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.<|im_end|>\n"
        qwen_instruct_prompt_pattern_chat_t2t = "<|im_start|>system\nYou are OSUM-chat, a text-to-text dialogue assistant by ASLP Lab. You understand user input in text then respond exclusively with appropriate text.<|im_end|>\n"
        qwen_instruct_prompt_pattern_1_understand = "<|im_start|>system\nYou are OSUM-chat, an audio understanding assistant by ASLP Lab. You can transcribe speech accurately and analyze paralinguistic cues to provide precise text responses.<|im_end|>\n"
        qwen_instruct_prompt_pattern_1_tts = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural and fluent speech from text input.<|im_end|>\n"
        qwen_instruct_prompt_pattern_1_tts_streaming = "<|im_start|>system\nYou are OSUM-chat, a speech synthesis assistant by ASLP Lab. You generate natural speech from text input and output both audio and the original text in interleaved format.<|im_end|>\n"
        qwen_instruct_prompt_pattern_1_old = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
        # user_start = "<|im_start|>user\n"

        # Assign a system prompt according to the output type.
        if output_type == "s2t_chat":
            system_prompt = qwen_instruct_prompt_pattern_chat_s2t
        elif output_type == "s2t_chat_fake":
            system_prompt = qwen_instruct_prompt_pattern_chat_s2s_think
        elif output_type == "text":
            system_prompt = qwen_instruct_prompt_pattern_1_understand
        elif output_type == "speech2text_token" or output_type == "speech2text_token_history":
            system_prompt = qwen_instruct_prompt_pattern_chat_s2s
        elif output_type == "text2token":
            system_prompt = qwen_instruct_prompt_pattern_1_tts
        elif output_type == "speech2text_token_streaming":
            system_prompt = qwen_instruct_prompt_pattern_chat_s2s_streaming
        elif output_type == "speech2text_token_think":
            system_prompt = qwen_instruct_prompt_pattern_chat_s2s_think
        elif output_type == "text2token_streaming":
            system_prompt = qwen_instruct_prompt_pattern_1_tts_streaming
        elif output_type == "text2text":
            system_prompt = qwen_instruct_prompt_pattern_chat_t2t
        else:
            system_prompt = qwen_instruct_prompt_pattern_1_old
        # if output_type == "speech2text_token_history":
        # if output_type == "text2text" or output_type == "text":
        #     qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_old
        # elif output_type == "speech2text_token" or output_type == "speech2text_token_streaming" or output_type == "text2text" or output_type == "s2t_chat":
        #     qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_chat
        # elif output_type == "text2token":
        #     qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_tts
        # else:
        #     qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_old
        system_prompt = system_prompt + "<|im_start|>user\n"

        rank = int(os.environ.get('RANK', 0))
        utils_file.logging_limit_print(f'xxx output_type {output_type}, rank {rank}')
        # if output_type == "s2t_chat":
        #     output_type = "text"
        # assert output_type in ['text', 'speech2text_token', 'text2token'], f"output_type:{output_type} not support"
        # speech inputs
        if output_type in ('text', 's2t_chat', 's2t_chat_fake', 'speech2text_token', 'speech2text_token_streaming', 'speech2text_token_think', 'speech2text_token_history'):
            wavs = batch['feats'].to(device)
            # utils_file.logging_limit_print(f'xxx wav shape {wavs.shape}')
            wavs_len = batch['feats_lengths'].to(device)
            B = wavs.shape[0]
            # utils_file.logging_limit_print(f"xxx {wavs_len}")
            speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
            # utils_file.logging_limit_print(f'xxx speech embeding shape {speech_embeds.shape}')
            # utils_file.logging_limit_print(f'xxx speech mask shape {speech_masks.shape}')
            # utils_file.logging_limit_print(f'xxx speech mask 0 {speech_masks[0]}')
            speech_target = torch.full(speech_masks.shape, self.IGNORE_ID).to(
                speech_embeds.device)
            # utils_file.logging_limit_print(f'xxx speech target shape {speech_target.shape}')
            # utils_file.logging_limit_print(f'xxx speech target 0 {speech_target[0]}')
            # Add speech bos/eos (ids speech_token_num and speech_token_num + 1,
            # presumably embedded via the two extra rows of speech_token_emded).
            speech_embeds, speech_masks, speech_target = self._add_bos_eos(0+self.speech_token_num,
                                                                           1+self.speech_token_num,
                                                                           speech_embeds, speech_masks, speech_target)
        elif output_type == "text2token" or output_type == "text2token_streaming":
            labels = batch['target'].to(device)
            labels_lengths = batch['target_lengths'].to(device) - 1  # minus 1 to drop the trailing eos
            B = labels.shape[0]
            # text2token: use the text token ids as the model input sequence.
            max_len = max(labels_lengths) + 1
            labels_pad_mask = make_pad_mask(labels_lengths, max_len=max_len)
            labels = labels.masked_fill(labels_pad_mask, 0)
            speech_embeds = self.embed_tokens(labels)  # B, L, D
            speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
                speech_embeds.device)
            speech_masks = ~labels_pad_mask

            # add bos and eos
            # speech_embeds, speech_masks, speech_target = self._add_bos_eos(0+self.speech_token_num,
            #                                                                1 + self.speech_token_num,
            #                                                                speech_embeds, speech_masks, speech_target)
        else: # text2text
            speech_embeds = None
            speech_masks = None
            speech_target = None
        # utils_file.logging_limit_print(f'xxx after add bos eos speech embeding shape {speech_embeds.shape}')
        # utils_file.logging_limit_print(f'xxx after add bos eos speech mask shape {speech_masks.shape}')
        # utils_file.logging_limit_print(f'xxx after add bos eos speech target shape {speech_target.shape}')
        # utils_file.logging_limit_print(f'xxx after add bos eos speech mask 0 {speech_masks[0]}')
        # utils_file.logging_limit_print(f'xxx after add bos eos speech target 0 {speech_target[0]}')

        # prompt
        if 'prompt' in batch:
            prompt = batch['prompt'].to(device)
            prompt_lengths = batch['prompt_lengths'].to(device)
            prompt_pad_mask = make_pad_mask(prompt_lengths)  # B, L
            prompt = prompt.masked_fill(prompt_pad_mask, self.tokenizer.eos_token_id)
            prompt_embeds = self.embed_tokens(prompt)  # B, L, D
            prompt_target = torch.full(prompt.shape, self.IGNORE_ID).to(
                device)  # B, L
            prompt_mask = ~prompt_pad_mask
            # utils_file.logging_limit_print(f'xxx prompt embeding shape {prompt_embeds.shape}')
            # utils_file.logging_limit_print(f'xxx prompt mask shape {prompt_mask.shape}')
            # utils_file.logging_limit_print(f'xxx prompt target shape {prompt_target.shape}')
        else:
            prompt_embeds = None
            prompt_mask = None
            prompt_target = None

        inputs_embeds_list = []
        attention_mask_list = []
        target_list = []
        prompt_pattern1 = self.tokenizer([system_prompt] * len(batch['target']), return_tensors="pt"
                                         )['input_ids'].to(device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
        prompt_pattern1_lens = torch.tensor([len(i) for i in prompt_pattern1]).to(device)
        prompt_pattern1_mask = ~make_pad_mask(prompt_pattern1_lens)
        prompt_pattern1_target = torch.full(prompt_pattern1.shape, self.IGNORE_ID).to(
            device)  # B, L

        # user_start_id = self.tokenizer([user_start] * len(batch['target']), return_tensors="pt"
        #                                 )['input_ids'].to(device)
        # user_start_embeds = self.embed_tokens(user_start_id)
        # user_start_lens = torch.tensor([len(i) for i in user_start_id]).to(device)
        # user_start_mask = ~make_pad_mask(user_start_lens)
        # user_start_target = torch.full(user_start_id.shape, self.IGNORE_ID).to(
        #     device)  # B, L

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(batch['target']), return_tensors="pt"
                                         )['input_ids'].to(device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
        prompt_pattern2_lens = torch.tensor([len(i) for i in prompt_pattern2]).to(device)
        prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens)
        prompt_pattern2_target = torch.full(prompt_pattern2.shape, self.IGNORE_ID).to(
            device)  # B, L

        inputs_embeds_list.append(prompt_pattern1_embeds)
        attention_mask_list.append(prompt_pattern1_mask)
        target_list.append(prompt_pattern1_target)
        streaming_error = False
        if output_type == "speech2text_token_streaming":
            rank = int(os.environ.get('RANK', 0))
            utils_file.logging_limit_print('start handling the speech2text_token streaming task')
            labels = batch['target'].to(device)
            labels_lengths = batch['target_lengths'].to(device)
            speech_token_labels = batch['speech_tokens'].to(device)
            speech_tokens_length = batch['speech_tokens_length'].to(device)

            labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
            labels = labels.masked_fill(labels_pad_mask, 0)

            speech_tokens_pad_mask = make_pad_mask(speech_tokens_length)  # B, L
            speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0)
            speech_token_labels = speech_token_labels + self.llm_vocab_size
            if rank == 0:
                utils_file.logging_limit_print(f'labels.shape {labels.shape}')
                utils_file.logging_limit_print(f'labels_lengths.shape {labels_lengths.shape}')
                utils_file.logging_limit_print(f'labels[0] {labels[0]}')
                utils_file.logging_limit_print(f'------------------------')
                utils_file.logging_limit_print(f'speech_token_labels.shape {speech_token_labels.shape}')
                utils_file.logging_limit_print(f'speech_tokens_length.shape {speech_tokens_length.shape}')
                utils_file.logging_limit_print(f'speech_token_labels[0] {speech_token_labels[0]}')
                utils_file.logging_limit_print(f'==========================')
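            # Interleave text ids and (offset) speech-token ids into a single
            # training stream; make_streaming_mode_from_s2s is assumed to chunk
            # them to match the alternating text/speech decode schedule used at
            # inference time.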
            streaming_concat_ids, streaming_concat_lens = make_streaming_mode_from_s2s(labels, labels_lengths,
                                                                                       speech_token_labels,
                                                                                       speech_tokens_length)
            if rank == 0:
                utils_file.logging_limit_print(f'streaming_concat_ids.shape {streaming_concat_ids.shape}')
                utils_file.logging_limit_print(f'streaming_concat_lens.shape {streaming_concat_lens.shape}')
                utils_file.logging_limit_print(f'streaming_concat_lens {streaming_concat_lens[0]}')
                utils_file.logging_limit_print(f'xxx streaming_concat_ids[0] : {streaming_concat_ids[0]}')
                utils_file.logging_limit_print(f'------------------------')
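            # Embed the mixed stream with both tables: ids below llm_vocab_size
            # are assumed to go through embed_tokens and the rest (minus the
            # offset) through speech_token_emded.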
            streaming_concat_embeddings = do_embedding_for_two_embeds(streaming_concat_ids, self.llm_vocab_size, self.embed_tokens,
                                                                      self.speech_token_emded)
            streaming_concat_pad_mask = make_pad_mask(streaming_concat_lens)
            streaming_concat_target = streaming_concat_ids.masked_fill(streaming_concat_pad_mask, self.IGNORE_ID)
            streaming_concat_mask = ~streaming_concat_pad_mask
            if rank == 0:
                utils_file.logging_limit_print(f'streaming_concat_embeddings.shape {streaming_concat_embeddings.shape}')
                utils_file.logging_limit_print(f'streaming_concat_mask shape {streaming_concat_mask.shape}')
                utils_file.logging_limit_print(f'------------------------')
            # if prompt_embeds is not None:  # s2s dialogue no longer uses the user prompt input
            #     inputs_embeds_list.append(prompt_embeds)
            #     attention_mask_list.append(prompt_mask)
            #     target_list.append(prompt_target)

            # ===================history===================================
            history_batch = batch.get('history', [])
            history_embedding, history_lens = self._get_embedding_for_history(history_batch, device)
            if history_embedding is not None:
                utils_file.logging_info('Geng Xuelong: embedding the dialogue history')
                history_pad_mask = make_pad_mask(history_lens)  # B, L
                history_target = torch.full(history_pad_mask.shape, self.IGNORE_ID).to(device)  # B, L
                history_mask = ~history_pad_mask
                inputs_embeds_list.append(history_embedding)
                attention_mask_list.append(history_mask)
                target_list.append(history_target)
                utils_file.logging_limit_print(f'xxx history embeding shape {history_embedding.shape}')
                utils_file.logging_limit_print(f'xxx history mask shape {history_mask.shape}')
                utils_file.logging_limit_print(f'xxx history target shape {history_target.shape}')
            else:
                utils_file.logging_limit_print(f'history is None')
            # ==========================history end ===================
            inputs_embeds_list.extend(
                [  speech_embeds, prompt_pattern2_embeds, streaming_concat_embeddings])
            attention_mask_list.extend([speech_masks, prompt_pattern2_mask, streaming_concat_mask])
            target_list.extend([speech_target, prompt_pattern2_target, streaming_concat_target])
        elif output_type == "text2token_streaming":
            rank = int(os.environ.get('RANK', 0))
            utils_file.logging_limit_print('start the TTS streaming task')
            labels = batch['target'].to(device)
            labels_lengths = batch['target_lengths'].to(device)
            speech_token_labels = batch['speech_tokens'].to(device)
            speech_tokens_length = batch['speech_tokens_length'].to(device)

            labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
            labels = labels.masked_fill(labels_pad_mask, 0)

            speech_tokens_pad_mask = make_pad_mask(speech_tokens_length)  # B, L
            speech_token_labels = speech_token_labels.masked_fill(speech_tokens_pad_mask, 0)
            speech_token_labels = speech_token_labels + self.llm_vocab_size
            streaming_concat_ids, streaming_concat_lens = make_streaming_mode_from_s2s(labels, labels_lengths,
                                                                                       speech_token_labels,
                                                                                       speech_tokens_length)
            streaming_concat_embeddings = do_embedding_for_two_embeds(streaming_concat_ids, self.llm_vocab_size,
                                                                      self.embed_tokens,
                                                                      self.speech_token_emded)
            streaming_concat_pad_mask = make_pad_mask(streaming_concat_lens)
            streaming_concat_target = streaming_concat_ids.masked_fill(streaming_concat_pad_mask, self.IGNORE_ID)
            streaming_concat_mask = ~streaming_concat_pad_mask
            # if prompt_embeds is not None:  # TTS no longer uses the user prompt input
            #     inputs_embeds_list.append(prompt_embeds)
            #     attention_mask_list.append(prompt_mask)
            #     target_list.append(prompt_target)
            inputs_embeds_list.extend(
                [ speech_embeds, prompt_pattern2_embeds, streaming_concat_embeddings])
            attention_mask_list.extend([speech_masks, prompt_pattern2_mask, streaming_concat_mask])
            target_list.extend([speech_target, prompt_pattern2_target, streaming_concat_target])

        elif output_type in ('speech2text_token', 'speech2text_token_think', 'speech2text_token_history'):
            utils_file.logging_limit_print('xxx start handling the speech2text_token task')
            labels = batch['target'].to(device)
            labels_lengths = batch['target_lengths'].to(device)
            speech_token_labels = batch['speech_tokens'].to(device)
            speech_tokens_length = batch['speech_tokens_length'].to(device)

            labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths)
            speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask = self.get_speech_token_label_embedding(
                speech_token_labels, speech_tokens_length)
            # if prompt_embeds is not None:  # s2s dialogue no longer uses the user prompt input
            #     inputs_embeds_list.append(prompt_embeds)
            #     attention_mask_list.append(prompt_mask)
            #     target_list.append(prompt_target)
            # ===================history===================================
            history_batch = batch.get('history', [])
            history_embedding, history_lens = self._get_embedding_for_history(history_batch, device)
            if history_embedding is not None:
                utils_file.logging_info('Geng Xuelong: embedding the dialogue history')
                history_pad_mask = make_pad_mask(history_lens)  # B, L
                history_target = torch.full(history_pad_mask.shape, self.IGNORE_ID).to(device)  # B, L
                history_mask = ~history_pad_mask
                inputs_embeds_list.append(history_embedding)
                attention_mask_list.append(history_mask)
                target_list.append(history_target)
                utils_file.logging_limit_print(f'xxx history embeding shape {history_embedding.shape}')
                utils_file.logging_limit_print(f'xxx history mask shape {history_mask.shape}')
                utils_file.logging_limit_print(f'xxx history target shape {history_target.shape}')
            else:
                utils_file.logging_limit_print(f'history is None')
            # ==========================history end ===================
            inputs_embeds_list.extend(
                [ speech_embeds, prompt_pattern2_embeds, labels_embeds, speech_token_labels_embeds])
            attention_mask_list.extend([speech_masks, prompt_pattern2_mask, labels_mask, speech_token_labels_mask])
            target_list.extend([speech_target, prompt_pattern2_target, labels_target, speech_token_labels_target])
        elif output_type == "text2token":
            speech_token_labels = batch['speech_tokens'].to(device)
            speech_tokens_length = batch['speech_tokens_length'].to(device)
            speech_token_labels_embeds, speech_token_labels_target, speech_token_labels_mask = self.get_speech_token_label_embedding(
                speech_token_labels, speech_tokens_length)
            # if prompt_embeds is not None:  # TTS no longer uses the user prompt input
            #     inputs_embeds_list.append(prompt_embeds)
            #     attention_mask_list.append(prompt_mask)
            #     target_list.append(prompt_target)
            inputs_embeds_list.extend([ speech_embeds, prompt_pattern2_embeds, speech_token_labels_embeds])
            attention_mask_list.extend([speech_masks, prompt_pattern2_mask, speech_token_labels_mask])
            target_list.extend([speech_target, prompt_pattern2_target, speech_token_labels_target])
        elif output_type == "text" or output_type == 's2t_chat' or output_type == "s2t_chat_fake":
            labels = batch['target'].to(device)
            labels_lengths = batch['target_lengths'].to(device)
            labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths)
            if prompt_embeds is not None and output_type == 'text':  # s2t_chat dialogue no longer uses the user prompt input
                inputs_embeds_list.append(prompt_embeds)
                attention_mask_list.append(prompt_mask)
                target_list.append(prompt_target)
            elif output_type != 's2t_chat' and output_type != "s2t_chat_fake":
                utils_file.logging_limit_print(
                    f'prompt is None,task: {batch["task"]}, prompt_embeds:{prompt_embeds}, prompt_mask:{prompt_mask}')
            inputs_embeds_list.extend([ speech_embeds, prompt_pattern2_embeds, labels_embeds])
            attention_mask_list.extend([speech_masks, prompt_pattern2_mask, labels_mask])
            target_list.extend([speech_target, prompt_pattern2_target, labels_target])
        elif output_type == "text2text":
            labels = batch['target'].to(device)
            labels_lengths = batch['target_lengths'].to(device)
            labels_embeds, labels_target, labels_mask = self.get_label_embedding(labels, labels_lengths)
            if prompt_embeds is not None:
                inputs_embeds_list.append(prompt_embeds)
                attention_mask_list.append(prompt_mask)
                target_list.append(prompt_target)
            else:
                utils_file.logging_limit_print(
                    f'prompt is None,task: {batch["task"]}, prompt_embeds:{prompt_embeds}, prompt_mask:{prompt_mask}')
            inputs_embeds_list.extend([ prompt_pattern2_embeds, labels_embeds])
            attention_mask_list.extend([ prompt_pattern2_mask, labels_mask])
            target_list.extend([ prompt_pattern2_target, labels_target])

        else:
            raise NotImplementedError(f'output_type {output_type} not support')

        inputs_embeds = torch.cat(inputs_embeds_list, dim=1)
        # utils_file.logging_limit_print(f'xxx final inputs_embeds shape {inputs_embeds.shape}')
        attention_mask = torch.cat(attention_mask_list, dim=1)
        # utils_file.logging_limit_print(f'xxx final attention_mask shape {attention_mask.shape}')
        # utils_file.logging_limit_print(f'xxx final attention_mask 0 {attention_mask[0]}')
        target = torch.cat(target_list, dim=1)
        # utils_file.logging_limit_print(f'xxx final  target shape {target.shape}')
        # utils_file.logging_limit_print(f'xxx final target 0 {target[0]}')
        # utils_file.logging_limit_print(f'Geng Xuelong output_type: {output_type}')
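        # Positions count only valid tokens: cumsum of the attention mask minus
        # one; padded slots get a dummy position of 1.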
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        # utils_file.logging_limit_print(f'xxx final position_ids shape {position_ids.shape}')
        # utils_file.logging_limit_print(f'xxx final position_ids 0 {position_ids[0]}')
        if output_type in ('text', 's2t_chat', 's2t_chat_fake', 'text2text'):
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                labels=target,
                attention_mask=attention_mask,
                position_ids=position_ids.to(inputs_embeds.device)
            )
            loss = outputs['loss']
            return {"loss": loss,"output_type": output_type}
        else:
            utils_file.logging_limit_print('running the custom dual-head forward of llama_model')
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                # labels=target,
                attention_mask=attention_mask,
                position_ids=position_ids.to(inputs_embeds.device)
            )
            hidden_states = outputs['hidden_states'][-1]
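            # Dual-head logits: the text LM head and the speech head are
            # concatenated along the vocab dimension, so targets offset by
            # llm_vocab_size select the speech half of the distribution.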
            logits = self.lm_head(hidden_states)
            logits2 = self.speech_head(hidden_states)  # speech_head
            combined_logits = torch.cat([logits, logits2], dim=-1)
            # combined_logits = self.new_lm_head(hidden_states)
            shift_logits = combined_logits[..., :-1, :].contiguous()
            shift_target = target[..., 1:].contiguous()
            # utils_file.logging_limit_print(
            #     f'xxx shift_logits shape: {shift_logits.shape}, shift_target shape: {shift_target.shape}')
            # utils_file.logging_limit_print(f'xxx shift_target 0 {shift_target[0]}')
            shift_logits = shift_logits.view(-1, combined_logits.shape[-1])  # flatten (B, T) over the combined text+speech vocab width
            shift_target = shift_target.view(-1)
            shift_target = shift_target.to(shift_logits.device)
            loss = self.loss_fct(shift_logits, shift_target)
            loss.requires_grad_(True)
            return {"loss": loss,"output_type": output_type}


    def generate_s2s_streaming(
            self,
            wavs,
            wavs_len,
            prompt,
    ):
        self.llama_model.eval()
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 +self.speech_token_num, 1+self.speech_token_num,
                                                           speech_embeds, speech_masks, None)

        device = speech_embeds.device

        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)

        qwen_instruct_prompt_pattern_1_s_input_chat = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant created by ASLP Lab. You understand both the meaning and paralinguistic cues in users' speech, and respond appropriately with text or voice.<|im_end|>\n<|im_start|>user\n"

        qwen_instruct_prompt_pattern_1_t2t_chat = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant created by ASLP Lab. You understand user input in text and respond with accurate and helpful text replies.<|im_end|>\n<|im_start|>user\n"
        qwen_instruct_prompt_pattern_1_old = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        qwen_instruct_prompt_pattern_1 = qwen_instruct_prompt_pattern_1_old
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1], return_tensors="pt"
                                         )['input_ids'].to(device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)
        prompt_pattern1_lens = torch.tensor([len(i) for i in prompt_pattern1]).to(device)
        prompt_pattern1_mask = ~make_pad_mask(prompt_pattern1_lens)
        prompt_pattern1_target = torch.full(prompt_pattern1.shape, self.IGNORE_ID).to(
            device)  # B, L

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2], return_tensors="pt"
                                         )['input_ids'].to(device)
        start_id = prompt_pattern2[0][-1]
        new_prompt_pattern2 = prompt_pattern2[:,:-1]
        prompt_pattern2_embeds = self.embed_tokens(new_prompt_pattern2)
        prompt_pattern2_lens = torch.tensor([len(i) for i in new_prompt_pattern2],dtype=torch.long).to(device)
        prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens)
        prompt_pattern2_target = torch.full(new_prompt_pattern2.shape, self.IGNORE_ID).to(
            device)  # B, L

        embeds = torch.cat([prompt_pattern1_embeds,prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype == torch.float16:
            # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
            embeds = embeds.to(torch.float16)
        max_len = 350
        hyps = [start_id]
        print(f'start_id: {start_id}')
        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )

        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1
        cache = llm_out.past_key_values
        token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        inferring_txt = True
        txt_finished = False
        speech_finished = False
        hyps_text = ""
        speech_eos_num = 0
        txt_list = []
        token_list = []
        i_num = 0
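        # Interleaved streaming decode: alternate chunks of up to 6 text tokens
        # (lm_head) and 18 speech tokens (speech_head) over a shared KV cache,
        # stopping on text eos plus a doubled speech eos (4096) or a 300-token
        # budget. _sampler is assumed to be a top-k/top-p sampling helper
        # defined elsewhere on this class.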
        for i in range(max_len):
            if inferring_txt and not txt_finished:
                for i_txt in range(6):
                    i_num += 1
                    if i_num > 300:
                        break
                    llm_out = self.llama_model(
                        inputs_embeds=token_emb,
                        past_key_values=cache,
                        output_hidden_states=True
                    )
                    cache = llm_out.past_key_values
                    hidden_states = llm_out.hidden_states[-1]
                    token_logits = self.lm_head(hidden_states)
                    next_token_ids = self._sampler(
                        token_logits,
                        temperatures_tensor,
                        top_ps_tensor,
                        top_ks_tensor,
                    )
                    # next_token_ids = torch.argmax(token_logits, dim=-1)
                    print(i_num, next_token_ids, f'txt')
                    hyps.append(next_token_ids.item())
                    token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
                    if next_token_ids == self.eos_token_id:
                        txt_finished = True
                        hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False)
                        print("hyps_text:", hyps_text)
                        print("text is over")
                        break
                    txt_list.append(next_token_ids.item())
                hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False)
                print("hyps_text:", hyps_text)
                inferring_txt = False
            elif not speech_finished:
                for i_speech in range(18):
                    i_num += 1
                    if i_num > 300:
                        break
                    llm_out = self.llama_model(
                        inputs_embeds=token_emb,
                        past_key_values=cache,
                        output_hidden_states=True
                    )
                    cache = llm_out.past_key_values
                    hidden_states = llm_out.hidden_states[-1]
                    token_logits = self.speech_head(hidden_states)
                    next_token_ids = self._sampler(
                        token_logits,
                        temperatures_tensor,
                        top_ps_tensor,
                        top_ks_tensor,
                    )
                    # next_token_ids = torch.argmax(token_logits, dim=-1)
                    hyps.append(next_token_ids.item())
                    print(i_num, next_token_ids, f'speech')
                    token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
                    if next_token_ids == 4096:
                        speech_eos_num += 1
                        print('hit speech eos 4096')
                        if speech_eos_num >= 2:
                            print("speech is over")
                            speech_finished = True
                        break
                    token_list.append(next_token_ids.item())
                    inferring_txt = True
            if speech_finished:
                break
            if i_num > 300:
                break
        return [hyps_text + "|" + str(token_list)]

    def generate(
            self,
            wavs,
            wavs_len,
            prompt,
            **kwargs
    ):
        self.llama_model.eval()
        # self.set_task_type("ASR")
        # self.do_add_speech_embed_head()
        # self.do_merge_embed_head()
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num,
                                                           speech_embeds, speech_masks, None)
        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)

        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        # # sft
        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, an audio understanding assistant by ASLP Lab. You can transcribe speech accurately and analyze paralinguistic cues to provide precise text responses.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)

        embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)

        if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16:
            # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
            # embeds = embeds.to(torch.float16)
            embeds = embeds.to(torch.bfloat16)
            atts = atts.to(torch.bfloat16)
        outputs = self.llama_model.generate(
            inputs_embeds=embeds,
            max_new_tokens=self.max_length,
            # cache_implementation="static",
            # num_beams=self.num_beams,
            do_sample=self.do_sample,
            min_length=self.min_length,
            top_p=self.top_p,
            top_k=self.top_k,
            repetition_penalty=self.repetition_penalty,
            length_penalty=self.length_penalty,
            temperature=self.temperature,
            # attention_mask=atts,
            eos_token_id=151645,
            pad_token_id=-100,
            # stopping_criteria=self.max_token_criteria_list,
            # do_compile=True,
        )
        output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)

        return output_text

    def generate4chat(
            self,
            wavs,
            wavs_len,
            prompt=" ",
            do_sample=True,
            top_k=2,
            top_p=1,
            temperature=0.4,
            **kwargs
    ):
        print(f'do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, temperature: {temperature}')
        self.llama_model.eval()
        # self.set_task_type("ASR")
        # self.do_add_speech_embed_head()
        # self.do_merge_embed_head()
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num,
                                                           speech_embeds, speech_masks, None)
        # prompt = self.tokenizer([prompt], return_tensors="pt"
        #                         )['input_ids'].to(speech_embeds.device)
        # prompt_embeds = self.embed_tokens(prompt)

        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        # # sft
        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-text dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond exclusively with appropriate text.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)

        embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)

        if self.embed_tokens.weight.dtype == torch.float16 or self.embed_tokens.weight.dtype == torch.bfloat16:
            # utils_file.logging_limit_print('generate(): self.embed_tokens.weight.dtype == torch.float16')
            # embeds = embeds.to(torch.float16)
            embeds = embeds.to(torch.bfloat16)
            atts = atts.to(torch.bfloat16)
        outputs = self.llama_model.generate(
            inputs_embeds=embeds,
            max_new_tokens=self.max_length,
            # cache_implementation="static",
            # num_beams=1,
            do_sample=do_sample,
            min_length=self.min_length,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=self.repetition_penalty,
            length_penalty=1,
            temperature=temperature,
            # attention_mask=atts,
            eos_token_id=151645,
            pad_token_id=-100,
            # do_compile=True,
            # stopping_criteria=self.max_token_criteria_list,
        )

        output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)

        return output_text


    def generate_s2s_no_stream(
            self,
            wavs,
            wavs_len,
            prompt,
    ):
        self.llama_model.eval()
        # self.set_task_type("S2S")
        # self.do_add_speech_embed_head()
        # ===================== prepare input embeddings =====================
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, None,
                                                           speech_embeds, speech_masks, None)

        device = speech_embeds.device

        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech, as well as input text, and respond appropriately.<|im_end|>\n<|im_start|>user\n"
        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are OSUM-chat, a speech-to-speech dialogue assistant by ASLP Lab. You understand both the meaning and paralinguistic cues in speech then respond with appropriate text and emotionally matching synthetic speech.<|im_end|>\n<|im_start|>user\n"
        # qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)

        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)

        hyps = [4098]  # assumed: a special start/task id in the speech-token embedding table
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)

        embeds = torch.cat(
            [prompt_pattern1_embeds, speech_embeds, token_emb, prompt_pattern2_embeds],
            dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype in (torch.float16, torch.bfloat16):
            # match the LLM's half-precision dtype; the speech encoder may emit float32
            embeds = embeds.to(self.embed_tokens.weight.dtype)

        # max_len = 350
        top_k = 10
        top_p = 0.9
        temperature = 1.2
        invalid_eos = 10000000  # out-of-vocab id, so only stopping_criteria can end generation

        llm_out = self.llama_model.generate(inputs_embeds=embeds,
                                            max_new_tokens=self.max_length,
                                            eos_token_id=invalid_eos,
                                            cache_implementation="static",
                                            do_sample=True,
                                            temperature=temperature,
                                            top_k=top_k,
                                            top_p=top_p,
                                            stopping_criteria=self.s2s_stop_criteria,
                                            do_compile=True,  # non-standard kwarg; assumes a patched generate()
                                            repetition_penalty=1.0,
                                            )

        # split at Qwen's <|im_end|> (151645): text tokens before it, speech tokens after it
        text_eos_idx = (llm_out[0] == 151645).nonzero(as_tuple=True)[0][0].item()
        text_res = llm_out[:, :text_eos_idx - 1]
        speech_res = llm_out[:, text_eos_idx + 1:-1]

        # print("llm_out", llm_out)
        output_text = self.tokenizer.batch_decode(text_res, add_special_tokens=False, skip_special_tokens=True)
        print(f'output_text:{output_text}')
        print(f'speech_res:{speech_res}')
        return (output_text, text_res, speech_res)
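
    # Sketch of the split performed at the end of generate_s2s_no_stream: the
    # LLM emits text ids, then Qwen's <|im_end|> (151645), then speech ids.
    # The helper below applies the same slicing to a dummy sequence; it is
    # illustrative only and not called by the pipeline.
    @staticmethod
    def _split_at_text_eos_demo():
        seq = torch.tensor([[11, 22, 33, 151645, 7, 8, 9, 4096]])  # fake ids
        eos_idx = (seq[0] == 151645).nonzero(as_tuple=True)[0][0].item()
        text_part = seq[:, :eos_idx - 1]      # mirrors the slicing above (also drops the last text id)
        speech_part = seq[:, eos_idx + 1:-1]  # speech ids, minus the trailing token
        return text_part, speech_part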


    def generate_s2s(
            self,
            wavs,
            wavs_len,
            prompt,
    ):
        self.llama_model.eval()
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 + self.speech_token_num, 1 + self.speech_token_num,
                                                           speech_embeds, speech_masks, None)

        device = speech_embeds.device

        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(wavs_len), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(wavs_len), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        new_prompt_pattern2 = prompt_pattern2[:, :-1]  # hold back the last token; it seeds the decode loop below
        prompt_pattern2_embeds = self.embed_tokens(new_prompt_pattern2)

        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds.to(device), speech_embeds, prompt_pattern2_embeds],
                           dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype in (torch.float16, torch.bfloat16):
            embeds = embeds.to(self.embed_tokens.weight.dtype)
        max_len = 320
        hyps = [prompt_pattern2[0][-1]]
        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )

        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1.2
        cache = llm_out.past_key_values
        token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        is_speech_token = False
        hyps_text = ""
        speech_eos_num = 0
        for i in range(max_len):
            llm_out = self.llama_model(
                inputs_embeds=token_emb,
                past_key_values=cache,
                output_hidden_states=True
            )
            cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]
            if is_speech_token:
                token_logits = self.speech_head(hidden_states)  # (B, )
            else:
                token_logits = self.lm_head(hidden_states)
            next_token_ids = self._sampler(
                token_logits,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            print(i, next_token_ids, f'is_speech_token:{is_speech_token}')
            if next_token_ids == self.eos_token_id:
                print("text is over")
                print("hyps:", hyps)
                is_speech_token = True
                hyps_text = self.tokenizer.decode(hyps[1:], skip_special_tokens=True, add_special_tokens=False)
                print("hyps_text:", hyps_text)
                hyps = []
            if is_speech_token and next_token_ids == self.speech_token_num - 1:
                speech_eos_num += 1
                print('hit speech EOS (4096)')
                if speech_eos_num >= 2:
                    print("break la!")
                    print("hyps:", hyps)
                    break
            hyps.append(next_token_ids.item())
            # if 1+i >= len(speech_token[0]):
            #     break
            # token_emb = self.speech_token_emded(torch.tensor([speech_token[0][i+1]]).to(device)).unsqueeze(0)
            if next_token_ids != self.eos_token_id and is_speech_token:
                token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
            else:
                token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        res = list(hyps[2:])  # optionally one could filter out self.speech_token_num - 1 here
        print(res)
        return [hyps_text + "|" + str(res)]
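
    # Decode state machine used by generate_s2s above (summary):
    #   1) sample with self.lm_head until the text EOS (Qwen's <|im_end|>)
    #      appears; the ids collected so far are decoded into the text reply;
    #   2) switch to self.speech_head / self.speech_token_emded and keep
    #      sampling until the speech EOS id (self.speech_token_num - 1, which
    #      the debug prints identify as 4096) has been seen twice, or max_len
    #      steps elapse.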

    def generate_tts(
            self,
            device,
            prompt,
            text,
    ):
        self.llama_model.eval()
        # labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device)
        # labels = text[:,:]
        labels = self.tokenizer(
            [text],
            return_tensors="pt",
            add_special_tokens=False
        ).to(
            self.embed_tokens.weight.device).input_ids  # (1, L)
        labels = labels.to(device)
        labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device)
        print(f'label_lengths:{labels_lengths}')
        print(f'labels:{labels}')
        labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
        labels = labels.masked_fill(labels_pad_mask, 0)
        speech_embeds = self.embed_tokens(labels)  # B, L, D
        speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
            speech_embeds.device)
        speech_masks = ~labels_pad_mask
        speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num,
                                                                       1 + self.speech_token_num,
                                                                       speech_embeds, speech_masks, speech_target)

        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)

        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
        # embeds = embeds[:,:-1,:]

        max_len = 250
        # hyps = [prompt_pattern2[0][-1]]
        hyps = [4096]  # seed with the speech boundary token (self.speech_token_num - 1)
        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )
        cache_hidden_states = llm_out.hidden_states[-1]
        print(f'cache_hidden_states:{cache_hidden_states.shape}') # (1, 51, 2048)
        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1.0
        cache = llm_out.past_key_values
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)

        for i in range(max_len):
            llm_out = self.llama_model(
                inputs_embeds=token_emb,
                past_key_values=cache,
                output_hidden_states=True
            )
            cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]  # -1 selects the last decoder layer
            token_logits = self.speech_head(hidden_states)
            print(f'token_logits:{token_logits.shape}')
            next_token_ids = self._sampler(
                token_logits,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            print(i, next_token_ids)
            if next_token_ids == self.speech_token_num - 1:
                print("break la!")
                print("hyps:", hyps)
                break
            hyps.append(next_token_ids.item())
            token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        res = list(hyps[1:])  # drop the seed token
        print(res)
        return [res]
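
    # A minimal sketch of the prefill-then-step loop used by generate_tts
    # above: run the LLM once over the full prompt, then feed back a single
    # token embedding per step while reusing past_key_values. `step_fn`,
    # `sample_fn` and `embed_fn` are hypothetical stand-ins for
    # self.llama_model, self._sampler and self.speech_token_emded; the helper
    # is illustrative and not called by the pipeline.
    @staticmethod
    def _kv_cache_decode_sketch(step_fn, sample_fn, embed_fn, bos_id, eos_id, max_len=250):
        hyps = [bos_id]
        cache = None  # filled in by the first call, then reused every step
        token_emb = embed_fn(bos_id)
        for _ in range(max_len):
            logits, cache = step_fn(token_emb, cache)  # one decode step with the KV cache
            next_id = sample_fn(logits)
            if next_id == eos_id:  # speech EOS terminates decoding
                break
            hyps.append(next_id)
            token_emb = embed_fn(next_id)  # feed back only the newly sampled token
        return hyps[1:]  # drop the seed token, as generate_tts does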

    def generate_tts_streaming(
            self,
            device,
            prompt,
            text,
    ):
        self.llama_model.eval()
        # labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device)
        # labels = text[:,:]
        labels = self.tokenizer(
            [text],
            return_tensors="pt",
            add_special_tokens=False
        ).to(
            self.embed_tokens.weight.device).input_ids  # (1, L)
        labels = labels.to(device)
        labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device)
        print(f'label_lengths:{labels_lengths}')
        print(f'labels:{labels}')
        labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
        labels = labels.masked_fill(labels_pad_mask, 0)
        speech_embeds = self.embed_tokens(labels)  # B, L, D
        speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
            speech_embeds.device)
        speech_masks = ~labels_pad_mask
        speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num,
                                                                       1 + self.speech_token_num,
                                                                       speech_embeds, speech_masks, speech_target)

        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)

        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        # embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
        # ----------------
        start_id = prompt_pattern2[0][-1]
        new_prompt_pattern2 = prompt_pattern2[:, :-1]
        prompt_pattern2_embeds = self.embed_tokens(new_prompt_pattern2)
        prompt_pattern2_lens = torch.tensor([len(i) for i in new_prompt_pattern2], dtype=torch.long).to(device)
        prompt_pattern2_mask = ~make_pad_mask(prompt_pattern2_lens)
        prompt_pattern2_target = torch.full(new_prompt_pattern2.shape, self.IGNORE_ID).to(
            device)  # B, L

        embeds = torch.cat([prompt_pattern1_embeds, prompt_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype in (torch.float16, torch.bfloat16):
            embeds = embeds.to(self.embed_tokens.weight.dtype)
        max_len = 350
        hyps = [start_id]
        print(f'start_id: {start_id}')
        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )

        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1
        cache = llm_out.past_key_values
        token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        inferring_txt = True
        txt_finished = False
        speech_finished = False
        hyps_text = ""
        speech_eos_num = 0
        txt_list = []
        token_list = []
        i_num = 0
        for i in range(max_len):
            if inferring_txt and not txt_finished:
                for i_txt in range(6):
                    i_num += 1
                    if i_num > 300:
                        break
                    llm_out = self.llama_model(
                        inputs_embeds=token_emb,
                        past_key_values=cache,
                        output_hidden_states=True
                    )
                    cache = llm_out.past_key_values
                    hidden_states = llm_out.hidden_states[-1]
                    token_logits = self.lm_head(hidden_states)
                    next_token_ids = self._sampler(
                        token_logits,
                        temperatures_tensor,
                        top_ps_tensor,
                        top_ks_tensor,
                    )
                    # next_token_ids = torch.argmax(token_logits, dim=-1)
                    print(i_num, next_token_ids, f'txt')
                    hyps.append(next_token_ids.item())
                    token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
                    if next_token_ids == self.eos_token_id:
                        txt_finished = True
                        hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False)
                        print("hyps_text:", hyps_text)
                        print("text is over")
                        break
                    txt_list.append(next_token_ids.item())
                hyps_text = self.tokenizer.decode(txt_list, skip_special_tokens=True, add_special_tokens=False)
                print("hyps_text:", hyps_text)
                inferring_txt = False
            elif not speech_finished:
                for i_speech in range(18):
                    i_num += 1
                    if i_num > 300:
                        break
                    llm_out = self.llama_model(
                        inputs_embeds=token_emb,
                        past_key_values=cache,
                        output_hidden_states=True
                    )
                    cache = llm_out.past_key_values
                    hidden_states = llm_out.hidden_states[-1]
                    token_logits = self.speech_head(hidden_states)
                    next_token_ids = self._sampler(
                        token_logits,
                        temperatures_tensor,
                        top_ps_tensor,
                        top_ks_tensor,
                    )
                    # next_token_ids = torch.argmax(token_logits, dim=-1)
                    hyps.append(next_token_ids.item())
                    print(i_num, next_token_ids, f'speech')
                    token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
                    if next_token_ids == 4096:
                        speech_eos_num += 1
                        print('hit speech EOS (4096)')
                        if speech_eos_num >= 2:
                            print("speech is over")
                            speech_finished = True
                        break
                    token_list.append(next_token_ids.item())
                    inferring_txt = True
            if i_num > 300:
                break
        return [hyps_text + "|" + str(token_list)]
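
    # generate_tts_streaming above interleaves the two heads in fixed chunks:
    # up to 6 text tokens, then up to 18 speech tokens, repeating until the
    # text head emits the text EOS, the speech head emits its EOS (4096)
    # twice, or 300 total steps elapse. A hypothetical helper that enumerates
    # the same schedule (illustrative only, not used by the pipeline):
    @staticmethod
    def _interleave_schedule(text_chunk=6, speech_chunk=18, total=300):
        labels = []
        while len(labels) < total:
            labels.extend(["text"] * min(text_chunk, total - len(labels)))
            labels.extend(["speech"] * min(speech_chunk, total - len(labels)))
        return labels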

    def generate_text2text(
            self,
            device,
            text,
    ):
        self.llama_model.eval()
        # labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=device)
        # labels = text[:,:]
        labels = self.tokenizer(
            [text],
            return_tensors="pt",
            add_special_tokens=False
        ).to(
            self.embed_tokens.weight.device).input_ids  # (1, L)
        labels = labels.to(device)
        labels_lengths = torch.tensor([len(labels[0])], dtype=torch.int64, device=device)
        # print(f'label_lengths:{labels_lengths}')
        # print(f'labels:{labels}')
        labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
        labels = labels.masked_fill(labels_pad_mask, 0)
        speech_embeds = self.embed_tokens(labels)  # B, L, D
        speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
            speech_embeds.device)
        speech_masks = ~labels_pad_mask
        # speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 + self.speech_token_num,
        #                                                                1 + self.speech_token_num,
        #                                                                speech_embeds, speech_masks, speech_target)

        qwen_instruct_prompt_pattern_1 = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
        prompt_pattern1 = self.tokenizer([qwen_instruct_prompt_pattern_1] * len(speech_embeds), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern1_embeds = self.embed_tokens(prompt_pattern1)

        qwen_instruct_prompt_pattern_2 = "<|im_end|>\n<|im_start|>assistant\n"
        prompt_pattern2 = self.tokenizer([qwen_instruct_prompt_pattern_2] * len(speech_embeds), return_tensors="pt"
                                         )['input_ids'].to(speech_embeds.device)
        prompt_pattern2_embeds = self.embed_tokens(prompt_pattern2)
        embeds = torch.cat([prompt_pattern1_embeds, speech_embeds, prompt_pattern2_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)

        if self.embed_tokens.weight.dtype in (torch.float16, torch.bfloat16):
            # match the LLM's half-precision dtype before generate()
            embeds = embeds.to(self.embed_tokens.weight.dtype)
            atts = atts.to(self.embed_tokens.weight.dtype)
        outputs = self.llama_model.generate(
            inputs_embeds=embeds,
            max_new_tokens=self.max_length,
            num_beams=self.num_beams,
            do_sample=self.do_sample,
            min_length=self.min_length,
            top_p=self.top_p,
            top_k=self.top_k,
            repetition_penalty=1.2,
            length_penalty=1.2,
            temperature=self.temperature,
            attention_mask=atts,
            eos_token_id=self.eos_token_id,
            pad_token_id=-100,
        )
        output_text = self.tokenizer.batch_decode(outputs, add_special_tokens=False, skip_special_tokens=True)
        # output_text = [item.replace('<|endoftext|>', '') for item in output_text]
        return output_text
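
    # The Qwen chat markup above is assembled by hand. With a recent
    # transformers release the same prompt can usually be produced from the
    # tokenizer's chat template (sketch; assumes the checkpoint ships a Qwen
    # template):
    #
    #   messages = [{"role": "system", "content": system_text},
    #               {"role": "user", "content": user_text}]
    #   ids = self.tokenizer.apply_chat_template(
    #       messages, add_generation_prompt=True, return_tensors="pt")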

    def _get_embedding_from_wav(self, wavs, wavs_len):
        """
        return:
        wav_embedding: (b, l, v)
        wav_mask:  (b, l), wav为有效值的位置为true
        """
        encoder_out, encoder_mask = self.encoder(wavs, wavs_len)

        speech_embeds, encoder_mask = self.down_sample_2(encoder_out, encoder_mask)
        if self.speech_transformer is not None:
            filled_wavs_len = encoder_mask.squeeze(1).sum(-1)
            speech_embeds, encoder_mask = self.speech_transformer(speech_embeds, filled_wavs_len)
            speech_embeds = self.speech_llama_proj(speech_embeds)

        return speech_embeds, encoder_mask.squeeze(1)

    def _get_embedding_from_text(self, text):
        """
        将字符串先量化，再转成词向量

        Args:
            text: str

        Returns:
            text_embeds: (1, L, D)

        """
        text_id = self.tokenizer(
            text,
            return_tensors="pt",
            add_special_tokens=False
        ).to(
            self.embed_tokens.weight.device).input_ids
        text_embeds = self.embed_tokens(text_id)
        text_embeds_len = torch.tensor([text_embeds.size(1)], dtype=torch.long)
        return text_embeds, text_embeds_len
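
    # Usage sketch (illustrative):
    #   text_embeds, text_embeds_len = self._get_embedding_from_text("hello")
    #   # text_embeds: (1, L, D), on the embedding table's device
    #   # text_embeds_len: tensor([L]), on CPU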

    def _add_bos_eos(self, bos, eos, inputs_embeds, attention_mask, target=None):
        B = len(inputs_embeds)
        bos_eos_target = torch.full([B, 1], self.IGNORE_ID).to(inputs_embeds.device)  # B,1
        bos_eos_mask = torch.full([B, 1], True).to(inputs_embeds.device)  # B, 1

        if bos is not None:
            bos_embed = self.speech_token_emded(torch.full([B, 1],
                                                           bos).to(inputs_embeds.device))  # B, 1, D
            inputs_embeds = torch.cat((bos_embed, inputs_embeds), 1)  # B, (1+T), D
            attention_mask = torch.cat((bos_eos_mask, attention_mask), 1)  # B, (1+T)
            if target is not None:
                target = torch.cat((bos_eos_target, target), 1)  # B, (1+T), D

        if eos is not None:
            eos_embed = self.speech_token_emded(torch.full([B, 1],
                                                           eos).to(inputs_embeds.device))  # B, 1, D
            inputs_embeds = torch.cat((inputs_embeds, eos_embed), 1)  # B, (1+T+1), D
            attention_mask = torch.cat((attention_mask, bos_eos_mask), 1)  # B, (1+T+1)
            if target is not None:
                target = torch.cat((target, bos_eos_target), 1)  # B, (1+T+1), D

        return inputs_embeds, attention_mask, target
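
    # A minimal shape check for the BOS/EOS padding above (illustrative; uses
    # a throwaway embedding table rather than self.speech_token_emded, and is
    # not called by the pipeline).
    @staticmethod
    def _bos_eos_shape_demo():
        B, T, D, V = 2, 5, 8, 16
        emb = nn.Embedding(V, D)
        x = torch.randn(B, T, D)
        mask = torch.ones(B, T, dtype=torch.bool)
        bos = emb(torch.full((B, 1), 0))  # B, 1, D
        eos = emb(torch.full((B, 1), 1))  # B, 1, D
        x = torch.cat((bos, x, eos), 1)   # B, (1+T+1), D
        one = torch.ones(B, 1, dtype=torch.bool)
        mask = torch.cat((one, mask, one), 1)  # B, (1+T+1)
        assert x.shape == (B, T + 2, D) and mask.shape == (B, T + 2)
        return x, mask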


    def infer_sample_teach_force(
            self,
            wavs,
            wavs_len,
            prompt,
            text,
            speech_token,
    ):
        labels_lengths = torch.tensor([len(text[0])], dtype=torch.int64, device=wavs.device)
        labels = text[:, :]
        labels_pad_mask = make_pad_mask(labels_lengths)  # B, L
        labels = labels.masked_fill(labels_pad_mask, 0)
        speech_embeds = self.embed_tokens(labels)  # B, L, D
        speech_target = torch.full(labels_pad_mask.shape, self.IGNORE_ID).to(
            speech_embeds.device)
        speech_masks = ~labels_pad_mask
        speech_embeds, speech_masks, speech_target = self._add_bos_eos(0 +self.speech_token_num,
                                                                       1 +self.speech_token_num,
                                                                       speech_embeds, speech_masks, speech_target)

        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype in (torch.float16, torch.bfloat16):
            utils_file.logging_limit_print('infer_sample_teach_force(): casting inputs to the LLM half-precision dtype')
            embeds = embeds.to(self.embed_tokens.weight.dtype)
            atts = atts.to(self.embed_tokens.weight.dtype)
        device = wavs.device
        inputs_embeds = embeds.to(device)

        speech_token_list = speech_token[0].tolist()
        speech_token_list_len = len(speech_token_list)
        print(f'speech_token_list_len:{speech_token_list_len}')
        max_len = 200
        hyps = [self.speech_token_num - 1]

        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )
        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1.0
        cache = llm_out.past_key_values
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)

        for i in range(max_len):
            llm_out = self.llama_model(
                inputs_embeds=token_emb,
                past_key_values=cache,
                output_hidden_states=True
            )
            cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]
            token_logits = self.speech_head(hidden_states)
            # probs =  F.log_softmax(token_logits[:,-1], dim=-1)[0]
            next_token_ids = self._sampler(
                token_logits,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            print(i, next_token_ids)
            if next_token_ids == self.speech_token_num - 1:
                print("break la!")
                print("hyps:", hyps)
                break
            hyps.append(next_token_ids.item())
            token_emb = self.speech_token_emded(torch.tensor(speech_token_list[i:i + 1]).to(device)).unsqueeze(0)  # 1-D slice keeps the batch dim
        res = list(hyps[1:])  # drop the seed token
        print(res)
        return [res]
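
    # Note on infer_sample_teach_force above: this is teacher forcing at
    # inference time. The model's own samples are recorded in `hyps`, but the
    # embedding fed at step i+1 always comes from the reference
    # speech_token_list[i], so the decoding context follows the ground-truth
    # token stream rather than the model's own predictions.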

    def _sampler(
            self,
            logits: torch.Tensor,
            temperatures: Union[torch.Tensor, None],
            top_ps: torch.Tensor,
            top_ks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Sample from logits.
        Args:
            logits: （1，1，vocab_size）
            temperatures:
            top_ps:
            top_ks:

        Returns:

        """
        print(f'logits:{logits.shape}')
        assert logits.size(1) == 1
        logits = logits.squeeze(1)  # (batch_size, vocab_size)
        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1)
        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))
        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        # Apply top-p, top-k.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        top_ks_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)
        # Re-normalization.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))
        next_token_ids = torch.multinomial(probs, num_samples=1,
                                           replacement=True).squeeze(dim=-1)
        return next_token_ids
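
    # Illustrative smoke test for _sampler (not called by the pipeline): with
    # top_k=1 and top_p=1.0 the sampling path must reduce to argmax, because
    # only the single largest probability survives the masking above.
    def _sampler_smoke_test(self):
        logits = torch.tensor([[[0.1, 2.0, -1.0, 0.5]]])  # (1, 1, vocab_size)
        next_id = self._sampler(
            logits,
            torch.FloatTensor([1.0]),  # temperature
            torch.FloatTensor([1.0]),  # top_p of 1.0 keeps every candidate
            torch.LongTensor([1]),     # top_k of 1 keeps only the argmax
        )
        assert next_id.item() == 1  # index of the largest logit above
        return next_id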

    def infer_sample4speech2text_token_teacher_force(
            self,
            wavs,
            wavs_len,
            prompt,
            speech_token=None,
            answer_text=None,
    ):
        self.llama_model.eval()
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 +self.speech_token_num, 1+self.speech_token_num,
                                                           speech_embeds, speech_masks, None)

        device = speech_embeds.device

        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        text_token = self.tokenizer([answer_text], return_tensors="pt"
                                    )['input_ids'].to(speech_embeds.device)
        text_token_embeds = self.embed_tokens(text_token)
        embeds = torch.cat([prompt_embeds, speech_embeds, text_token_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype in (torch.float16, torch.bfloat16):
            utils_file.logging_limit_print('infer_sample4speech2text_token_teacher_force(): casting inputs to the LLM half-precision dtype')
            embeds = embeds.to(self.embed_tokens.weight.dtype)
            atts = atts.to(self.embed_tokens.weight.dtype)
        inputs_embeds = embeds.to(speech_embeds.device)

        max_len = 150
        hyps = [self.speech_token_num]

        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )
        # speech_token_list = speech_token[0]
        # speech_token_list_len = len(speech_token_list)
        if speech_token is not None:
            print(f'speech_token_list_len:{len(speech_token[0])}')
            print(f'speech_token:{speech_token[0]}')

        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1.2
        cache = llm_out.past_key_values
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        is_speech_token = False
        for i in range(max_len):
            llm_out = self.llama_model(
                inputs_embeds=token_emb,
                past_key_values=cache,
                output_hidden_states=True
            )
            cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]
            token_logits = self.speech_head(hidden_states)
            # probs =  F.log_softmax(token_logits[:,-1], dim=-1)[0]
            # if i ==2 or i == 80:
            #     torch.save(probs, f'probs_{i}.pt')
            next_token_ids = self._sampler(
                token_logits,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            print(i, next_token_ids, f'is_speech_token:{is_speech_token}')
            if next_token_ids == self.speech_token_num - 1:
                print('hit speech EOS (4096)')
                break
            hyps.append(next_token_ids.item())
            # if 1+i >= len(speech_token[0]):
            #     break
            # token_emb = self.speech_token_emded(torch.tensor([speech_token[0][i+1]]).to(device)).unsqueeze(0)
            token_emb = self.embed_tokens(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        res = list(hyps[1:])  # drop the seed token
        print(res)
        return [answer_text + str(res[2:])]

    def infer_sample4speech2text_token_teacher_force2(
            self,
            wavs,
            wavs_len,
            prompt,
            speech_token=None,
            answer_text=None,
    ):
        self.llama_model.eval()
        speech_embeds, speech_masks = self._get_embedding_from_wav(wavs, wavs_len)
        speech_embeds, speech_masks, _ = self._add_bos_eos(0 +self.speech_token_num, 1+self.speech_token_num ,
                                                           speech_embeds, speech_masks, None)

        device = speech_embeds.device

        prompt = self.tokenizer([prompt], return_tensors="pt"
                                )['input_ids'].to(speech_embeds.device)
        prompt_embeds = self.embed_tokens(prompt)
        text_token = self.tokenizer([answer_text], return_tensors="pt"
                                    )['input_ids'].to(speech_embeds.device)
        # text_token_embeds = self.embed_tokens(text_token)
        embeds = torch.cat([prompt_embeds, speech_embeds], dim=1)
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
        if self.embed_tokens.weight.dtype in (torch.float16, torch.bfloat16):
            utils_file.logging_limit_print('infer_sample4speech2text_token_teacher_force2(): casting inputs to the LLM half-precision dtype')
            embeds = embeds.to(self.embed_tokens.weight.dtype)
            atts = atts.to(self.embed_tokens.weight.dtype)
        inputs_embeds = embeds.to(speech_embeds.device)

        hyps = [self.speech_token_num - 1]
        hyps_text = ""

        llm_out = self.llama_model(
            inputs_embeds=embeds,
            past_key_values=None,
            output_hidden_states=True
        )
        # speech_token_list = speech_token[0]
        # speech_token_list_len = len(speech_token_list)
        if speech_token is not None:
            print(f'speech_token_list_len:{len(speech_token)}')
            print(f'speech_token:{speech_token}')

        batch_size = 1
        top_k = 10
        top_p = 0.9
        temperature = 1.2
        cache = llm_out.past_key_values
        token_emb = self.speech_token_emded(torch.tensor(hyps[-1:]).to(device)).unsqueeze(0)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        temperatures_tensor = None if not temperature else torch.FloatTensor(
            [temperature] * batch_size).to(device)
        is_speech_token = False
        token_num = len(speech_token)
        for i in range(token_num):
            llm_out = self.llama_model(
                inputs_embeds=token_emb,
                past_key_values=cache,
                output_hidden_states=True
            )
            cache = llm_out.past_key_values
            hidden_states = llm_out.hidden_states[-1]
            token_logits = self.speech_head(hidden_states)
            # probs =  F.log_softmax(token_logits[:,-1], dim=-1)[0]
            # if i ==2 or i == 80:
            #     torch.save(probs, f'probs_{i}.pt')
            next_token_ids = self._sampler(
                token_logits,
                temperatures_tensor,
                top_ps_tensor,
                top_ks_tensor,
            )
            print(i, next_token_ids, f'is_speech_token:{is_speech_token}')
            if next_token_ids == self.speech_token_num - 1:
                print('hit speech EOS (4096)')
                break
            hyps.append(next_token_ids.item())
            # if 1+i >= len(speech_token[0]):
            #     break
            # token_emb = self.speech_token_emded(torch.tensor([speech_token[0][i+1]]).to(device)).unsqueeze(0)
            token_emb = self.embed_tokens(torch.tensor([speech_token[i]]).to(device)).unsqueeze(0)
        res = list(hyps[1:])  # drop the seed token
        print(res)
        return [hyps_text + str(res)]